{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# imports\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "# transforms\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5,), (0.5,))])\n",
    "\n",
    "# datasets\n",
    "trainset = torchvision.datasets.FashionMNIST('./data',\n",
    "                                             download=True,\n",
    "                                             train=True,\n",
    "                                             transform=transform)\n",
    "testset = torchvision.datasets.FashionMNIST('./data',\n",
    "                                            download=True,\n",
    "                                            train=False,\n",
    "                                            transform=transform)\n",
    "\n",
    "# dataloaders\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,\n",
    "                                          shuffle=True, num_workers=2)\n",
    "\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n",
    "                                         shuffle=False, num_workers=2)\n",
    "\n",
    "# constant for classes\n",
    "classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
    "           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')\n",
    "\n",
    "\n",
    "# helper function to show an image\n",
    "# (used in the `plot_classes_preds` function below)\n",
    "def matplotlib_imshow(img, one_channel=False):\n",
    "    if one_channel:\n",
    "        img = img.mean(dim=0)\n",
    "    img = img / 2 + 0.5  # unnormalize\n",
    "    npimg = img.numpy()\n",
    "    if one_channel:\n",
    "        plt.imshow(npimg, cmap=\"Greys\")\n",
    "    else:\n",
    "        plt.imshow(np.transpose(npimg, (1, 2, 0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 4 * 4, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 4 * 4)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "net = Net()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 1. TensorBoard 设置\n",
    "现在我们设置 TensorBoard，从 torch.utils 导入 tensorboard 并定义 SummaryWriter ，这是我们用于将信息写入到的关键对象tensor板。"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "writer = SummaryWriter('run.fashion_mnist_experiment_1')\n",
    "# 请注意，这一行单独创建了一个 runs/fashion_mnist_experiment_1 文件夹。"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 写入TensorBoard\n",
    "将图像写入我们的TensorBoard。使用make_grid."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "# get some random training images\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = next(dataiter)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "# create grid of images\n",
    "img_grid = torchvision.utils.make_grid(images)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "data": {
      "text/plain": "<Figure size 640x480 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiYAAACxCAYAAADwMnaUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAl+ElEQVR4nO3de1RU5f4/8DeogIagYNxUlE4XNS+ZFyJbfrtQ5uqkpZWVJZVndSo00U6plboqO2RXu5h2Ti09rTLLtdTSdbQMC7PwhlreIktTFIHUuIiCBPv3R8f5+byHZjMyxJ7h/VqLtXrPbPZsnr33+DTPZ54nyLIsCyIiIiIOENzUByAiIiJymjomIiIi4hjqmIiIiIhjqGMiIiIijqGOiYiIiDiGOiYiIiLiGOqYiIiIiGOoYyIiIiKOoY6JiIiIOIY6JiIiIuIYjdYxmTNnDrp27YqwsDAkJydj48aNjfVSIiIiEiCCGmOtnA8//BBjxozBvHnzkJycjNmzZ2Px4sXIy8tDTEyMx9+tra1FQUEB2rZti6CgIF8fmoiIiDQCy7JQXl6OhIQEBAef/ecejdIxSU5OxoABA/DGG28A+L2z0blzZ4wfPx5Tpkzx+LsHDx5E586dfX1IIiIi8ifIz89Hp06dzvr3W/rwWAAAp06dQm5uLqZOnep6LDg4GKmpqcjJyXHbvqqqClVVVa58up80c+ZMhIWF+frwREREpBFUVlbiySefRNu2bRu0H593TI4cOYKamhrExsYaj8fGxuL777932z4zMxNPPfWU2+NhYWFo3bq1rw9PREREGlFDyzCa/Fs5U6dORWlpqesnPz+/qQ9JREREmojPPzHp0KEDWrRogaKiIuPxoqIixMXFuW0fGhqK0NBQXx+GiIiI+CGff2ISEhKCfv36ISsry/VYbW0tsrKykJKS4uuXExERkQDi809MAGDSpElIS0tD//79MXDgQMyePRsVFRW49957G+PlREREJEA0Ssdk1KhR+OWXXzB9+nQUFhbikksuwapVq9wKYs/WQw895JP9SNN68803PT7vD+e5trbWyPzd/ZqaGiN//fXXHre/4oorPL4ef7vfH+b68cfzvGjRIiO/9957Rt66dauRb731ViMPHz7cyKdOnTLyZ599ZuSVK1ca+Z577jHyY4895vmAHcAfzzP74IMPjLxs2TIj83QX/G9aRUWFkRMSEoz81VdfGfn0lBqnzZs3z8gN+cptY7E7z77QKB0TABg3bhzGjRvXWLsXERGRANTk38oREREROU0dExEREXGMRhvKEQlEXONhtx7E5MmTjTxjxgwjFxcXG/lvf/ubkd9++20jc02JP9acOFF8fLyRebqDdu3aGblVq1ZGfvXVVz1mO9HR0UaeOXOmkV9++WUjr1692si9evXy6vWaq8zMTCPz/XXy5Ekj8ySfd955p5HDw8ONPGDAACMvWbLEyFwzUlBQYOTLL7/cyFFRUUZ+5plnjHzjjTciEOkTExEREXEMdUxERETEMdQxEREREcdQjYmIF3jekhYtWhh57969Rub5LXjVTc49evQw8g8//GDkCy+80MiqMTk7lZWVRubags6dOxv5zBXQAfd279Kli5HtrhOuTTpx4oSR+booKSkx8uOPP27k5cuXQ0zHjx93e4xrdbp27WpkPq+//fabkfm88nn75ZdfPB4T74+XaeHX5+s0IyPDyKoxEREREWlk6piIiIiIY6hjIiIiIo6hGhMRL9jVcCxdutTI48eP92r/vP2qVauMzDUmcnaeeOIJI/OaRlwLwGP/jGtUuBbBDs+LUl1dbWSez+Knn37yav/NEddjAECbNm08/k55ebmReY0jrg3i9wNeK6dlS/OfWD6vvD3Pl8P4OuX5bp588kmPv+8v9ImJiIiIOIY6JiIiIuIY6piIiIiIY6hjIiIiIo6h4lcRD7iIkYvfuJiNi9NCQkK8ej0ugvz111+9Oh5NuFY/
mzdvNjK3G2eeII0nvuLF3uyKX3n/XGTJi8PxxF1cNMlF0tdff73H128OSktL3R4LCwszMhc52xW38v1sd7+GhoYama8LLnbl64Bfnydk27VrFwKRPjERERERx1DHRERERBxDHRMRERFxDNWYiHhgN7HWjh07jOxtTYddTcjRo0eNvH79eiNffvnlRuYaF57gSX6XnZ1tZF6kb9++fUbu3r27kXmRP8btzueZn+dF+lJTU4187bXXGplrG1RT4m7x4sVuj7Vv397IERERRrZbbJHbne+3wsJCI3PtEdeo8HXH++P7n69Lfj8IFPrERERERBxDHRMRERFxDHVMRERExDE0AC3igV2NyaZNm4zcq1cvr/ZnV2Ny2WWXGfnzzz83MteY2B2v1I3nm+jWrZvH7e3mu7DDtQdcY5KYmGjkUaNGebV/qVtubq6RBw8ebGSuQeHFE7kGhOdKOe+884x84sQJI3ONCc+HU1xcbGS+zg4fPozmQJ+YiIiIiGOoYyIiIiKOoY6JiIiIOIZqTBzKrlbAbkz77bffNvLQoUON3LFjR9tj2LNnj8ffadOmjcff98d1W+zmm2B79+418pAhQ7zav12b8Ji13Rgzz7Pgj+fgz9DQduE1Vvg64f3x6/H8FSwpKcnj81zr0NCal+aC76eZM2caecqUKUbmmhNeyyY6OtrIdueV50mxu254jaTmQp+YiIiIiGOoYyIiIiKO4XXHZO3atbjxxhuRkJCAoKAgLFu2zHjesixMnz4d8fHxaN26NVJTU92GBERERETq4nWNSUVFBfr06YP77rsPI0aMcHv++eefx2uvvYb//Oc/SEpKwrRp0zBkyBDs2rULYWFhPjnoQMBjzrW1tUbmsUjGY9xc2/Ddd98Zmdd0mT17tu0x7ty508g//fSTkbnmhOfw8Mdxbm9rD3hMuUuXLh6397ZNYmNjjXzOOecYmce8eZ4Eb68rqRvXfHC729WUMK4J4e0vvPBCj7/Pr+eP91pjq+sccDtxrQ6/r548edLj75eVlRmZa0Z4f3zd8DwnfL/y+wvPt8Pb83Xlr7zumAwdOtStkPI0y7Iwe/ZsPPnkkxg+fDgA4N1330VsbCyWLVuG22+/vWFHKyIiIgHNp92rffv2obCw0FgZMzIyEsnJycjJyanzd6qqqlBWVmb8iIiISPPk047J6SWf+ePn2NhYt+WgT8vMzERkZKTrx245cREREQlcTT6PydSpUzFp0iRXLisrC4jOid3YH49V2o39P/roo0Z+8cUXjdy9e3cj8/fv6zNvCbvpppuM/OOPPxr50ksvNfLu3bs9vmYgrOPy888/G9luLhfW0FqAtm3bGpnbvE+fPj59vUDlbS0R1w7YzVfBtQB2r8f3P9+/4r361JjEx8cbmesg7c473/98HrmGhTNv722NSKDe3z79xCQuLg4AUFRUZDxeVFTkeo6FhoYiIiLC+BEREZHmyacdk6SkJMTFxSErK8v1WFlZGTZs2ICUlBRfvpSIiIgEIK+Hco4fP258pL9v3z5s27YNUVFRSExMREZGBmbOnIkLLrjA9XXhhIQEt2EBEREREeZ1x2Tz5s246qqrXPl0fUhaWhoWLFiAxx57DBUVFbj//vtRUlKCK664AqtWrfL7OUy8HZO2GyvkYuBvvvnGyBkZGUY+fvy4kXv27GlkHtM+cuSIkZ944gkj8/fzAbh9nTs/P9/IkydPNnKHDh2MvGDBAo+v2dTjob6oceHJAsPDwxu0P2+vK57HZN++fUb2RY1Jc1hfx+7+rK6uNjLfrzxfDNcO8Bon3KZcQ8LfRuTzyvOaBMp8FU2N79/KykqPmf8d4/PA85LwvCX8enbznPDaVyxQ71WvOyZXXnmlxzf4oKAgPP3003j66acbdGAiIiLS/KjbLSIiIo6hjomIiIg4RpPPY+ILPLRkl+szPtvQtSi4JuTtt9828iOPPGLkqKgoI5977rlGTkhIMHJBQYGR+fv2vGYLz6syY8YM
t2Ou67Ez8fwyAwYMMPKuXbuM/GePf9q9Xn1e326bDz74wMg9evRo9GM6E9c68JpIXGR+Nm0eKOPUntidl48//tjIF198sZG5RoTnt7CrDeDn+fc/+eQTI/NaWGKvPtex3dpTfJ74uikuLjZyTEyMkbn2r7y83Mh83nn/fDz+XqtZX/rERERERBxDHRMRERFxDHVMRERExDECosakoeP2Z4PnHZgwYYKReU6PM+d+AdzrM/j78jzPyK+//mrk1q1bG5kXTuQ1XXjOkaSkJDCuX+DX4DkyuFbnyy+/NDLXwZzNej2eNMV3+HnegWHDhnnc3tfHNHDgQCO//vrrPt1/c2F3Xt566y0jDx482Mi8dAbfr1xbwPOe8L3DNSsLFy40Mtci8P7t1uaSukVHRxuZ25Hfl3mtKs78nse4lpD3z2vn8HXDArUeTFeviIiIOIY6JiIiIuIY6piIiIiIYwREjQnj2gP+7jivZ7Bz5063fXz99ddG/u9//2vk7du3GzkxMdHIgwYNMjKvncHff+exTsZjjbwWzt13323k/fv3G5nHrHmeE8C97oTHzXntHP6b7GpUfI3HV9evX2/k5cuXG5nrdPbu3eu2zyuvvNLIPObLNSavvvqqkfk88rVmVxfDa7TEx8cbmc9rmzZtjDxz5kwj8zni4wHca39KS0uNzPUP06dPd9uHv+MaDp43iO+/oqIiI/M8QhUVFUbm64jPO19X3OZcM3bRRRcZ2RfrQDVHOTk5RubzwtcFXwf8PNce8XXB23Pm6+bYsWNG5lrBQKVPTERERMQx1DERERERx1DHRERERBzDL2tM+LvmN998s8fta2pqjMw1JzwnCeC+RgHXS/Tu3dvja3KtAI8x8xoNfIw81njXXXcZeffu3UZ+6qmnjDxlyhQj33bbbUZ+77333I6ZX5PHS7lNeLyVz8uqVauMfOedd7q9pi/x+PCBAweMbLcOBQCsWLHCyDxfBGeuNeDX4PPKY858XfD2vP4Q47/5m2++8Xi8dc1vweeNa2/S0tI8HkMgOHTokJH5PcHu/uV25fPC552Fh4cb+fzzzzcy1wpxjYm4q88cH08//bSReV4Sfo/gNdB4HhK7Wh+7Y+L98ft0RkaGx98PFPrERERERBxDHRMRERFxDHVMRERExDHUMRERERHH8MviV14g76uvvjJycnKykbmgkAsCueCpLnbFqYyL3bgYtlWrVkYePXq0kb/99lsjv/POO0b++9//bmQujps1a5aRuYCRJxID3Ccb4kng7Ar6uB3tFqDyNT7vS5YsMfLRo0eNXJ/zztvw5GO8KBe3ERdFcqEpF1XydcJtXFJSYmS+ltu1a+fxeOuaoIkL7tg999zj8flAsHbtWiPz/cmZi+H5/uPzysWtXFTJ5533b/d+I2eH29mumNVukT0+T1wkzeed9x8ZGWnkuXPnGpmLX7WIn4iIiEgjU8dEREREHEMdExEREXEMv6wx4XF1Hpfjycd43J0XyKoLj93x2CIv1sZjxPwavAAej0Vy3UxKSoqRebKwxx9/3Mh//etfPR7f4sWLjcw1KoD7wmVcD9GjRw8j9+nTx8iDBw82Mi84x9nXuE6AF2fkSatyc3Pd9nHhhRca2W6RLcY1IdyGnOua8OxMXNtU16RwZ+IJoHgMnBcyBNyPmWtxeAHIxp4orynwtcOLI/J1wLUDfP/z81yrxPVbvLgi17QcPny4rsN2CdRaA1/jyTX5fZnPG7cr3/98P/N55ue5hoX/XeF/2/g6bC70iYmIiIg4hjomIiIi4hjqmIiIiIhj+GWNCc/5wbUDvHjcunXrjPzTTz8Z2W6cH3Af87WrJeBFwXjsn/HYJc+P8dlnnxmZa0i6dOliZF6A77fffjMyj20C7n+jHbtF8XiBu82bN3u1f2+NGTPGyMOGDTMyL8zG47mAe00G/012Y77cJt7WlNjtj8/rsWPHjMz1VPw31jWPSUFBgZG5hoQXkAzEGhOuPejYsaOR7ea3
4GsrPj7eyHwe+H7k64zPs12NibfXVXPF77N8Hu3mXuJaI/53gK8jnr+GryOuj2Q8P05D30/8RWD+VSIiIuKXvOqYZGZmYsCAAWjbti1iYmJw0003IS8vz9imsrIS6enpiI6ORnh4OEaOHOn2f84iIiIidfGqY5KdnY309HSsX78eq1evRnV1Na677jrj47GJEydi+fLlWLx4MbKzs1FQUIARI0b4/MBFREQk8HhVY8K1GwsWLEBMTAxyc3MxePBglJaW4p133sHChQtx9dVXAwDmz5+P7t27Y/369bjssst8d+Rn4Pk0eC6KRx55xMg8Xstr7QDAli1bjMxjhzy2x2OJXAPCfztnrg2ww/Nb8Ngnj3nbreECuI9z8zg4j8+ePHnSyDw3A4+H3nbbbUaeP3++2zF4g9tg2rRpRh47dqyRT1+TnvB8MZ06dTIyzzfBtQA8Zs2ZzxPjv8luTJlrXvi88/Y//vij22vyp57Tp0838sMPP2xku3WfAgGvoRIWFmZk/pv5/YFrD7iWgO8tu/W8uNZAzg7XePB54PvJ7rzydcG1e3Zr5/D9zdcdvz7vrz7rffmjBtWYnF4g7HShZm5uLqqrq5Gamuraplu3bkhMTHRbIE5ERESEnfW3cmpra5GRkYFBgwahZ8+eAIDCwkKEhIS4/d9/bGwsCgsL69xPVVWV8X+R/H98IiIi0nyc9Scm6enp2LFjBxYtWtSgA8jMzERkZKTrp3Pnzg3an4iIiPivs/rEZNy4cVixYgXWrl1rjMHHxcXh1KlTKCkpMT41KSoqQlxcXJ37mjp1KiZNmuTKZWVlXndOeJyPM68Pws+PGjXKbZ88JwaPLXqLxxJ57JBrG3h+CrvaBD4+Huvk53lsFXAf5+a6FN4Hz7XC88nweKmvcZtdd911Rua5Y7jup645PXiMmWsJuGaEa06Y3fwxnPk6YVzXw+eMj5frQer6m/kTziNHjhi5W7duRrZbr8cf8Xni88ztztsfPHjQyDwfxulh7z/aH9cC8Xms634V7/E3RPn+4/dZbnc+L3xe7eaP4vuVn+f7n+9fruNTjQl+v1nHjRuHpUuXYs2aNUhKSjKe79evH1q1aoWsrCzXY3l5eThw4IDbonSnhYaGIiIiwvgRERGR5smrjwHS09OxcOFCfPzxx2jbtq2rbiQyMhKtW7dGZGQkxo4di0mTJiEqKgoREREYP348UlJSGu0bOSIiIhI4vOqYzJ07FwBw5ZVXGo/Pnz8f99xzDwDglVdeQXBwMEaOHImqqioMGTIEb775pk8OVkRERAKbVx0THnetS1hYGObMmYM5c+ac9UF5i4+LayPat2/vMddVJ8Dj7HZjg3ZjhVyfwfUXdvOg8NgmZ/6bnag+1483uM15aJFri3h8NjY21m2fXJfCY9BcC8Bj0nyd8O/bPc/4ebu1PPg64PkweMwacK936Nq1q5GPHj1q5J9//tnI/fv393hMTsT3p13tD9+vfO3xewi/x/D9z9cBZ75XNI+Jb3CNid17Er/P8v1ndx3w/u3et3l7fr/h6yRQaa0cERERcQx1TERERMQx1DERERERx2jY5BwO0dD6Cp674o8ek4bxdR0M1+UwXpdi586dRq7rHHN9Bdec8Bgz12zwGDHP+cG1C3b1Ud7WNjG7Woq6XpPrm/j5jz76yMj+WGPCa47wWD7PT8Fz9nD9EtcW8FxJfB65ho3nkuGalD+aOVu8ExMTY2S+f/m88fN8v/G9wb/PNWh8P9rNFdVc56/RJyYiIiLiGOqYiIiIiGOoYyIiIiKOERA1JtI88Vo5U6dONfIbb7xh5GuuucbI+/fvd9sn1wLwGk92axbZrY3Bv8+1DXZ1ODwGzdnu9+sas+a6E669eemll4xcWVnp8TX8Ac/NwvO9cA0K1xbxWlc8FwzPV8HnmWtGuDbhzDXI6tq/nB27eYDsakI48/3N9wbXqPC9xvcr16Tx69m9/wQKfWIiIiIijqGOiYiIiDiGOiYi
IiLiGKoxEb+VnJxs5L59+xqZx5N57om65jHZvn27kePj443MtQjMbt4SuzVS7GpEeN4T3j/XNrRt29bIhw8fdtun3RwZjzzyiJHrWm/H3/A8JXwt5OfnG7ljx45G3rhxo5H52rI777ze0Pnnn29kXpuHa1rs5r+RunGtD+N25MzvKfx+wLVA/Hp2tUeMa0zqWtctEOkTExEREXEMdUxERETEMdQxEREREcdQjYkEDB7/5XlOeJye1ycB3Oct4XlCeJ4C3mdD17awm8fArr6Dx7y5toGPFwASEhK8OUS3cXJ/xPOU8LpLXDPSoUMHI5eVlRmZ25n3d+zYMSPzekRck3Lo0CEjX3LJJUbm6ywQzsmfwW4tGrvaHa7x4BoRrlXieUe4BoWPh3+fn8/LyzMy19kFCn1iIiIiIo6hjomIiIg4hjomIiIi4hjqmIiIiIhjqPhV/BYXhnIh2pIlS4zMk41xwSHgXtzKr8ETX9ktwsXFdXbFsXZFjHUd85m4CJPbpK5FwHr37u1xn8yu3f3Bhg0bjMztEhsba2QurOa/mSdI48LqgwcPGvn48eMej4+vVb6uuIg5KirK4/7kd3z/cLvydcD3ExeX8+/zxH18P/N1weeRi9v5/Wjnzp1oDvzvHUVEREQCljomIiIi4hjqmIiIiIhjqMZE/JZdbcOYMWOMvGjRIiPXtSBWSUmJkT/66CPb33ESrnk577zzjMwThwHAxRdf7HGfgVBTwq666iojv/vuu0betWuXkW+++WYjHz161OP2R44cMXJOTo6Ri4uLjcwTtnGtAf/+L7/8YmTVmNQPL8qZkpLicXuuOeH7h2tO+LrgCdL2799v5KSkJCNzjcoNN9xg5MzMTI/HGyj8/x1GREREAoY6JiIiIuIY6piIiIiIY6jGRAIWj7uvX7/eyJ988onb7+zevdvIvKhWYmKikcPCwozMi4DxvAk8LwHXEvC8CTzvCW/P8yjYLQJYV41JRkaG22OBjhfFy83NNTK32/bt243Mc+QcPnzYyHweX3zxRSPPmjXL4+t37tzZyD/++KORuZZI6ofnEVm3bp2R+bxnZ2cbmevU+Dy3b9/eyP369TMyv19s2rTJyFwDM27cODRH+sREREREHMOrjsncuXPRu3dvREREICIiAikpKVi5cqXr+crKSqSnpyM6Ohrh4eEYOXIkioqKfH7QIiIiEpi86ph06tQJzz33HHJzc7F582ZcffXVGD58uGua3IkTJ2L58uVYvHgxsrOzUVBQgBEjRjTKgYuIiEjgCbJ4UNxLUVFReOGFF3DLLbfg3HPPxcKFC3HLLbcAAL7//nt0794dOTk5uOyyy+q1v7KyMkRGRuLFF190W5dEREREnOnkyZP4xz/+gdLSUkRERJz1fs66xqSmpgaLFi1CRUUFUlJSkJubi+rqaqSmprq26datGxITE90mBzpTVVUVysrKjB8RERFpnrzumGzfvh3h4eEIDQ3FAw88gKVLl6JHjx4oLCxESEiIW9VzbGwsCgsL/3B/mZmZiIyMdP1wNbqIiIg0H153TC666CJs27YNGzZswIMPPoi0tDS36Zi9MXXqVJSWlrp+8vPzz3pfIiIi4t+8nsckJCQE559/PoDfv6O9adMmvPrqqxg1ahROnTqFkpIS41OToqIixMXF/eH+QkND3dYbEBERkeapwfOY1NbWoqqqCv369UOrVq2QlZXlei4vLw8HDhywXShJREREBPDyE5OpU6di6NChSExMRHl5ORYuXIgvv/wSn376KSIjIzF27FhMmjQJUVFRiIiIwPjx45GSklLvb+SIiIhI8+ZVx6S4uBhjxozB4cOHERkZid69e+PTTz/FtddeCwB45ZVXEBwcjJEjR6KqqgpDhgzBm2++6dUBnf72Mk/dLSIiIs51+t/tBs5C0vB5THzt4MGD+maOiIiIn8rPz0enTp3O+vcd1zGpra1FQUEBLMtCYmIi8vPzGzRRS3NXVlaGzp07qx0bQG3Y
cGpD31A7NpzasOH+qA0ty0J5eTkSEhIQHHz2JayOW104ODgYnTp1ck20dnpdHmkYtWPDqQ0bTm3oG2rHhlMbNlxdbRgZGdng/Wp1YREREXEMdUxERETEMRzbMQkNDcWMGTM0+VoDqR0bTm3YcGpD31A7NpzasOEauw0dV/wqIiIizZdjPzERERGR5kcdExEREXEMdUxERETEMdQxEREREcdwbMdkzpw56Nq1K8LCwpCcnIyNGzc29SE5VmZmJgYMGIC2bdsiJiYGN910E/Ly8oxtKisrkZ6ejujoaISHh2PkyJEoKipqoiN2vueeew5BQUHIyMhwPaY2rJ9Dhw7hrrvuQnR0NFq3bo1evXph8+bNructy8L06dMRHx+P1q1bIzU1FXv27GnCI3aWmpoaTJs2DUlJSWjdujX+8pe/4JlnnjHWH1EbmtauXYsbb7wRCQkJCAoKwrJly4zn69Nex44dw+jRoxEREYF27dph7NixOH78+J/4VzQ9T+1YXV2NyZMno1evXjjnnHOQkJCAMWPGoKCgwNiHL9rRkR2TDz/8EJMmTcKMGTOwZcsW9OnTB0OGDEFxcXFTH5ojZWdnIz09HevXr8fq1atRXV2N6667DhUVFa5tJk6ciOXLl2Px4sXIzs5GQUEBRowY0YRH7VybNm3CW2+9hd69exuPqw3t/frrrxg0aBBatWqFlStXYteuXXjppZfQvn171zbPP/88XnvtNcybNw8bNmzAOeecgyFDhmjhzv+ZNWsW5s6dizfeeAO7d+/GrFmz8Pzzz+P11193baM2NFVUVKBPnz6YM2dOnc/Xp71Gjx6NnTt3YvXq1VixYgXWrl2L+++//8/6ExzBUzueOHECW7ZswbRp07BlyxYsWbIEeXl5GDZsmLGdT9rRcqCBAwda6enprlxTU2MlJCRYmZmZTXhU/qO4uNgCYGVnZ1uWZVklJSVWq1atrMWLF7u22b17twXAysnJaarDdKTy8nLrggsusFavXm393//9nzVhwgTLstSG9TV58mTriiuu+MPna2trrbi4OOuFF15wPVZSUmKFhoZaH3zwwZ9xiI53ww03WPfdd5/x2IgRI6zRo0dblqU2tAPAWrp0qSvXp7127dplAbA2bdrk2mblypVWUFCQdejQoT/t2J2E27EuGzdutABY+/fvtyzLd+3ouE9MTp06hdzcXKSmproeCw4ORmpqKnJycprwyPxHaWkpACAqKgoAkJubi+rqaqNNu3XrhsTERLUpSU9Pxw033GC0FaA2rK9PPvkE/fv3x6233oqYmBj07dsX//73v13P79u3D4WFhUY7RkZGIjk5We34P5dffjmysrLwww8/AAC+/fZbrFu3DkOHDgWgNvRWfdorJycH7dq1Q//+/V3bpKamIjg4GBs2bPjTj9lflJaWIigoCO3atQPgu3Z03CJ+R44cQU1NDWJjY43HY2Nj8f333zfRUfmP2tpaZGRkYNCgQejZsycAoLCwECEhIa6L57TY2FgUFhY2wVE606JFi7BlyxZs2rTJ7Tm1Yf3s3bsXc+fOxaRJk/D4449j06ZNePjhhxESEoK0tDRXW9V1f6sdfzdlyhSUlZWhW7duaNGiBWpqavDss89i9OjRAKA29FJ92quwsBAxMTHG8y1btkRUVJTa9A9UVlZi8uTJuOOOO1wL+fmqHR3XMZGGSU9Px44dO7Bu3bqmPhS/kp+fjwkTJmD16tUICwtr6sPxW7W1tejfvz/++c9/AgD69u2LHTt2YN68eUhLS2vio/MPH330Ed5//30sXLgQF198MbZt24aMjAwkJCSoDcURqqurcdttt8GyLMydO9fn+3fcUE6HDh3QokULt287FBUVIS4uromOyj+MGzcOK1aswBdffIFOnTq5Ho+Li8OpU6dQUlJibK82/f9yc3NRXFyMSy+9FC1btkTLli2RnZ2N1157DS1btkRsbKzasB7i4+PRo0cP47Hu3bvjwIEDAOBqK93ff+zR
Rx/FlClTcPvtt6NXr164++67MXHiRGRmZgJQG3qrPu0VFxfn9uWK3377DceOHVObktOdkv3792P16tWuT0sA37Wj4zomISEh6NevH7KyslyP1dbWIisrCykpKU14ZM5lWRbGjRuHpUuXYs2aNUhKSjKe79evH1q1amW0aV5eHg4cOKA2/Z9rrrkG27dvx7Zt21w//fv3x+jRo13/rTa0N2jQILevqv/www/o0qULACApKQlxcXFGO5aVlWHDhg1qx/85ceIEgoPNt+YWLVqgtrYWgNrQW/Vpr5SUFJSUlCA3N9e1zZo1a1BbW4vk5OQ//Zid6nSnZM+ePfj8888RHR1tPO+zdjyLYt1Gt2jRIis0NNRasGCBtWvXLuv++++32rVrZxUWFjb1oTnSgw8+aEVGRlpffvmldfjwYdfPiRMnXNs88MADVmJiorVmzRpr8+bNVkpKipWSktKER+18Z34rx7LUhvWxceNGq2XLltazzz5r7dmzx3r//fetNm3aWO+9955rm+eee85q166d9fHHH1vfffedNXz4cCspKck6efJkEx65c6SlpVkdO3a0VqxYYe3bt89asmSJ1aFDB+uxxx5zbaM2NJWXl1tbt261tm7dagGwXn75ZWvr1q2ub4vUp72uv/56q2/fvtaGDRusdevWWRdccIF1xx13NNWf1CQ8teOpU6esYcOGWZ06dbK2bdtm/FtTVVXl2ocv2tGRHRPLsqzXX3/dSkxMtEJCQqyBAwda69evb+pDciwAdf7Mnz/ftc3Jkyethx56yGrfvr3Vpk0b6+abb7YOHz7cdAftB7hjojasn+XLl1s9e/a0QkNDrW7duln/+te/jOdra2utadOmWbGxsVZoaKh1zTXXWHl5eU10tM5TVlZmTZgwwUpMTLTCwsKs8847z3riiSeMN3+1oemLL76o8z0wLS3Nsqz6tdfRo0etO+64wwoPD7ciIiKse++91yovL2+Cv6bpeGrHffv2/eG/NV988YVrH75oxyDLOmM6QREREZEm5LgaExEREWm+1DERERERx1DHRERERBxDHRMRERFxDHVMRERExDHUMRERERHHUMdEREREHEMdExEREXEMdUxERETEMdQxEREREcdQx0REREQcQx0TERERcYz/B4TQTXmwSqJ3AAAAAElFTkSuQmCC"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# show images\n",
    "matplotlib_imshow(img_grid, one_channel=True)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "# write to tensorboard\n",
    "writer.add_image('four_fashion_mnist_images', img_grid)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 3. 使用 TensorBoard 检查模型\n",
    "TensorBoard’s 的优势之一是其能够可视化复杂模型 结构。让可视化我们构建的模型"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [],
   "source": [
    "\n",
    "writer.add_graph(net, images)\n",
    "writer.close()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 4. 添加 “Projector” 到 TensorBoard\n",
    "我们可以通过 add_embedding 方法"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [],
   "source": [
    "# helper function\n",
    "def select_n_random(data, labels, n=100):\n",
    "    '''\n",
    "    Selects n random datapoints and their corresponding labels from a dataset\n",
    "    '''\n",
    "    assert len(data) == len(labels)\n",
    "\n",
    "    perm = torch.randperm(len(data))\n",
    "    return data[perm][:n], labels[perm][:n]\n",
    "\n",
    "\n",
    "# select random images and their target indices\n",
    "images, labels = select_n_random(trainset.data, trainset.targets)\n",
    "\n",
    "# get the class labels for each image\n",
    "class_labels = [classes[lab] for lab in labels]\n",
    "\n",
    "# log embeddings\n",
    "features = images.view(-1, 28 * 28)\n",
    "writer.add_embedding(features,\n",
    "                     metadata=class_labels,\n",
    "                     label_img=images.unsqueeze(1))\n",
    "writer.close()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 5. 使用 TensorBoard 跟踪模型训练\n",
    "在前面的示例中，我们只是打印模型运行损失,每2000次迭代。现在，我们将运行损失记录到 TensorBoard，并通过 plot_classes_preds 函数查看模型 所做的预测。"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [],
   "source": [
    "\n",
    "# helper functions\n",
    "\n",
    "def images_to_probs(net, images):\n",
    "    '''\n",
    "    Generates predictions and corresponding probabilities from a trained\n",
    "    network and a list of images\n",
    "    '''\n",
    "    output = net(images)\n",
    "    # convert output probabilities to predicted class\n",
    "    _, preds_tensor = torch.max(output, 1)\n",
    "    preds = np.squeeze(preds_tensor.numpy())\n",
    "    return preds, [F.softmax(el, dim=0)[i].item() for i, el in zip(preds, output)]\n",
    "\n",
    "\n",
    "def plot_classes_preds(net, images, labels):\n",
    "    '''\n",
    "    Generates matplotlib Figure using a trained network, along with images\n",
    "    and labels from a batch, that shows the network's top prediction along\n",
    "    with its probability, alongside the actual label, coloring this\n",
    "    information based on whether the prediction was correct or not.\n",
    "    Uses the \"images_to_probs\" function.\n",
    "    '''\n",
    "    preds, probs = images_to_probs(net, images)\n",
    "    # plot the images in the batch, along with predicted and true labels\n",
    "    fig = plt.figure(figsize=(12, 48))\n",
    "    for idx in np.arange(4):\n",
    "        ax = fig.add_subplot(1, 4, idx + 1, xticks=[], yticks=[])\n",
    "        matplotlib_imshow(images[idx], one_channel=True)\n",
    "        ax.set_title(\"{0}, {1:.1f}%(label: {2})\".format(\n",
    "            classes[preds[idx]],\n",
    "            probs[idx] * 100.0,\n",
    "            classes[labels[idx]]),\n",
    "            color=(\"green\" if preds[idx] == labels[idx].item() else \"red\"))\n",
    "    return fig"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[20], line 4\u001B[0m\n\u001B[1;32m      1\u001B[0m running_loss \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0.0\u001B[39m\n\u001B[1;32m      2\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m epoch \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(\u001B[38;5;241m1\u001B[39m):  \u001B[38;5;66;03m# loop over the dataset multiple times\u001B[39;00m\n\u001B[0;32m----> 4\u001B[0m     \u001B[38;5;28;01mfor\u001B[39;00m i, data \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(trainloader, \u001B[38;5;241m0\u001B[39m):\n\u001B[1;32m      5\u001B[0m \n\u001B[1;32m      6\u001B[0m         \u001B[38;5;66;03m# get the inputs; data is a list of [inputs, labels]\u001B[39;00m\n\u001B[1;32m      7\u001B[0m         inputs, labels \u001B[38;5;241m=\u001B[39m data\n\u001B[1;32m      9\u001B[0m         \u001B[38;5;66;03m# zero the parameter gradients\u001B[39;00m\n",
      "File \u001B[0;32m~/Desktop/study4ai/study4ai_venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:631\u001B[0m, in \u001B[0;36m_BaseDataLoaderIter.__next__\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m    628\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_sampler_iter \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m    629\u001B[0m     \u001B[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001B[39;00m\n\u001B[1;32m    630\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_reset()  \u001B[38;5;66;03m# type: ignore[call-arg]\u001B[39;00m\n\u001B[0;32m--> 631\u001B[0m data \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_next_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    632\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_num_yielded \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[1;32m    633\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_dataset_kind \u001B[38;5;241m==\u001B[39m _DatasetKind\u001B[38;5;241m.\u001B[39mIterable \u001B[38;5;129;01mand\u001B[39;00m \\\n\u001B[1;32m    634\u001B[0m         \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_IterableDataset_len_called \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m \\\n\u001B[1;32m    635\u001B[0m         \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_num_yielded \u001B[38;5;241m>\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_IterableDataset_len_called:\n",
      "File \u001B[0;32m~/Desktop/study4ai/study4ai_venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1329\u001B[0m, in \u001B[0;36m_MultiProcessingDataLoaderIter._next_data\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m   1326\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_process_data(data)\n\u001B[1;32m   1328\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_shutdown \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_tasks_outstanding \u001B[38;5;241m>\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m-> 1329\u001B[0m idx, data \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_get_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m   1330\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_tasks_outstanding \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[1;32m   1331\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_dataset_kind \u001B[38;5;241m==\u001B[39m _DatasetKind\u001B[38;5;241m.\u001B[39mIterable:\n\u001B[1;32m   1332\u001B[0m     \u001B[38;5;66;03m# Check for _IterableDatasetStopIteration\u001B[39;00m\n",
      "File \u001B[0;32m~/Desktop/study4ai/study4ai_venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1295\u001B[0m, in \u001B[0;36m_MultiProcessingDataLoaderIter._get_data\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m   1291\u001B[0m     \u001B[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001B[39;00m\n\u001B[1;32m   1292\u001B[0m     \u001B[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001B[39;00m\n\u001B[1;32m   1293\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m   1294\u001B[0m     \u001B[38;5;28;01mwhile\u001B[39;00m \u001B[38;5;28;01mTrue\u001B[39;00m:\n\u001B[0;32m-> 1295\u001B[0m         success, data \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_try_get_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m   1296\u001B[0m         \u001B[38;5;28;01mif\u001B[39;00m success:\n\u001B[1;32m   1297\u001B[0m             \u001B[38;5;28;01mreturn\u001B[39;00m data\n",
      "File \u001B[0;32m~/Desktop/study4ai/study4ai_venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1133\u001B[0m, in \u001B[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001B[0;34m(self, timeout)\u001B[0m\n\u001B[1;32m   1120\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m_try_get_data\u001B[39m(\u001B[38;5;28mself\u001B[39m, timeout\u001B[38;5;241m=\u001B[39m_utils\u001B[38;5;241m.\u001B[39mMP_STATUS_CHECK_INTERVAL):\n\u001B[1;32m   1121\u001B[0m     \u001B[38;5;66;03m# Tries to fetch data from `self._data_queue` once for a given timeout.\u001B[39;00m\n\u001B[1;32m   1122\u001B[0m     \u001B[38;5;66;03m# This can also be used as inner loop of fetching without timeout, with\u001B[39;00m\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m   1130\u001B[0m     \u001B[38;5;66;03m# Returns a 2-tuple:\u001B[39;00m\n\u001B[1;32m   1131\u001B[0m     \u001B[38;5;66;03m#   (bool: whether successfully get data, any: data if successful else None)\u001B[39;00m\n\u001B[1;32m   1132\u001B[0m     \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m-> 1133\u001B[0m         data \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_data_queue\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mget\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtimeout\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtimeout\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m   1134\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m (\u001B[38;5;28;01mTrue\u001B[39;00m, data)\n\u001B[1;32m   1135\u001B[0m     \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mException\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m e:\n\u001B[1;32m   1136\u001B[0m         \u001B[38;5;66;03m# At timeout and error, we manually check whether any worker has\u001B[39;00m\n\u001B[1;32m   1137\u001B[0m         \u001B[38;5;66;03m# failed. 
Note that this is the only mechanism for Windows to detect\u001B[39;00m\n\u001B[1;32m   1138\u001B[0m         \u001B[38;5;66;03m# worker failures.\u001B[39;00m\n",
      "File \u001B[0;32m~/opt/anaconda3/lib/python3.8/multiprocessing/queues.py:116\u001B[0m, in \u001B[0;36mQueue.get\u001B[0;34m(self, block, timeout)\u001B[0m\n\u001B[1;32m    114\u001B[0m         \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_rlock\u001B[38;5;241m.\u001B[39mrelease()\n\u001B[1;32m    115\u001B[0m \u001B[38;5;66;03m# unserialize the data after having released the lock\u001B[39;00m\n\u001B[0;32m--> 116\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43m_ForkingPickler\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mloads\u001B[49m\u001B[43m(\u001B[49m\u001B[43mres\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/Desktop/study4ai/study4ai_venv/lib/python3.8/site-packages/torch/multiprocessing/reductions.py:514\u001B[0m, in \u001B[0;36mrebuild_storage_filename\u001B[0;34m(cls, manager, handle, size, dtype)\u001B[0m\n\u001B[1;32m    512\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m storage\u001B[38;5;241m.\u001B[39m_shared_decref()\n\u001B[1;32m    513\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m dtype \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m--> 514\u001B[0m     storage \u001B[38;5;241m=\u001B[39m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mUntypedStorage\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_new_shared_filename_cpu\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmanager\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mhandle\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msize\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    515\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m    516\u001B[0m     byte_size \u001B[38;5;241m=\u001B[39m size \u001B[38;5;241m*\u001B[39m torch\u001B[38;5;241m.\u001B[39m_utils\u001B[38;5;241m.\u001B[39m_element_size(dtype)\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "running_loss = 0.0\n",
    "for epoch in range(1):  # single sweep over the training data\n",
    "\n",
    "    for batch_idx, batch in enumerate(trainloader):\n",
    "\n",
    "        # each batch from the DataLoader is a [inputs, labels] pair\n",
    "        inputs, labels = batch\n",
    "\n",
    "        # clear gradients accumulated from the previous step\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward pass, loss, backprop, weight update\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "        if batch_idx % 1000 == 999:  # every 1000 mini-batches...\n",
    "            global_step = epoch * len(trainloader) + batch_idx\n",
    "\n",
    "            # ...log the mean loss over the last 1000 mini-batches\n",
    "            writer.add_scalar('training loss',\n",
    "                              running_loss / 1000,\n",
    "                              global_step)\n",
    "\n",
    "            # ...log a Matplotlib figure of the model's predictions on\n",
    "            # this mini-batch\n",
    "            writer.add_figure('predictions vs. actuals',\n",
    "                              plot_classes_preds(net, inputs, labels),\n",
    "                              global_step=global_step)\n",
    "            running_loss = 0.0\n",
    "print('Finished Training')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 6. 使用 TensorBoard 评估经过训练的模型\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 1. gets the probability predictions in a test_size x num_classes Tensor\n",
    "# 2. gets the preds in a test_size Tensor\n",
    "# takes ~10 seconds to run\n",
    "class_probs = []\n",
    "class_label = []\n",
    "with torch.no_grad():  # inference only; no autograd bookkeeping needed\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        output = net(images)\n",
    "        # per-sample softmax turns the logits into class probabilities\n",
    "        class_probs_batch = [F.softmax(el, dim=0) for el in output]\n",
    "\n",
    "        class_probs.append(class_probs_batch)\n",
    "        class_label.append(labels)\n",
    "\n",
    "test_probs = torch.cat([torch.stack(batch) for batch in class_probs])\n",
    "test_label = torch.cat(class_label)\n",
    "\n",
    "\n",
    "# helper function\n",
    "def add_pr_curve_tensorboard(class_index, test_probs, test_label, global_step=0):\n",
    "    '''\n",
    "    Takes in a \"class_index\" from 0 to 9 and plots the corresponding\n",
    "    precision-recall curve\n",
    "    '''\n",
    "    tensorboard_truth = test_label == class_index\n",
    "    tensorboard_probs = test_probs[:, class_index]\n",
    "\n",
    "    writer.add_pr_curve(classes[class_index],\n",
    "                        tensorboard_truth,\n",
    "                        tensorboard_probs,\n",
    "                        global_step=global_step)\n",
    "\n",
    "\n",
    "# plot all the pr curves\n",
    "for i in range(len(classes)):\n",
    "    add_pr_curve_tensorboard(i, test_probs, test_label)\n",
    "\n",
    "# close the writer once, after every curve has been logged; closing it\n",
    "# inside the per-class helper (as before) made SummaryWriter close and\n",
    "# reopen itself on every loop iteration, creating one event file per class\n",
    "writer.close()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}